from pyspark.ml.regression import LinearRegression, LinearRegressionModel
from utils.spark_util import col_validation, one_hot_encoding, feature_cols_labelling, \
    one_hot_encoding_p, get_model_info, cast_to_float, label_col_labelling


def predict(sc, dataframe, para_list, output):
    """Score *dataframe* with a previously trained linear-regression model.

    Parameters
    ----------
    sc : SparkContext passed through to the helper utilities.
    dataframe : input Spark DataFrame to score.
    para_list : dict with at least "feature_col", "model_id" and "source".
    output : dict mutated in place — encoding warnings are appended to
        output["message"] and the produced column names are stored under
        output["result"]["output_params"]["output_cols"].

    Returns
    -------
    Spark DataFrame with one added prediction column (named "_prediction_",
    suffixed with a counter when that name already exists in the input).
    """
    print("LR PRED BEGIN")
    # --------------------prepare feature data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])
    print(num_cols, cat_cols)

    model_save_path, feature_list_train = get_model_info(sc, para_list["model_id"])

    # Rebuild, per categorical column, the category list the model was trained
    # with.  Encoded training features look like "<feature>_OneHot/<category>"
    # and consecutive entries sharing the same <feature> belong to the same
    # column, which maps positionally onto cat_cols.  We track the previously
    # seen feature name: the old code probed the dict keys (`feature in
    # cat_col_dict`), which silently assumed trained feature names were
    # identical to the current cat_cols entries and could walk `i` past the
    # end of cat_cols (IndexError) when they were not.
    cat_col_dict = {}
    prev_feature = None
    i = -1
    for col in feature_list_train:
        if "_OneHot/" in col:
            # split once: a category value may itself contain the marker
            feature, category = col.split("_OneHot/", 1)
            if feature == prev_feature:
                cat_col_dict[cat_cols[i]].append(category)
            else:
                prev_feature = feature
                i += 1
                cat_col_dict[cat_cols[i]] = [category]

    if len(cat_cols) != 0:
        dataframe, message, encoded_cat_cols = one_hot_encoding_p(sc, dataframe, para_list["source"], cat_cols
                                                                  , cat_col_dict)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    print("FEATURE LIST")
    df = feature_cols_labelling(dataframe, feature_list)

    # --------------------load the model-----------------

    model = LinearRegressionModel.load(model_save_path)
    print("MODEL LOADED")
    # --------------------make predictions----------------

    df_pred = model.transform(df)
    # "rawPrediction"/"probability" do not exist for regression output;
    # dropping a missing column is a no-op in Spark, so this stays safe.
    df_pred = cast_to_float(df_pred, "prediction").drop("features").drop("label") \
        .drop("rawPrediction").drop("probability")
    print("DF_PRED DONE")

    # Pick a prediction column name that does not collide with input columns.
    new_col_name_temp = "_prediction_"
    new_col_name = new_col_name_temp
    count = 1
    while new_col_name in dataframe.columns:
        count += 1
        new_col_name = new_col_name_temp + "_" + str(count)
    df_pred = df_pred.withColumnRenamed("prediction", new_col_name)
    print("DF_PRED NEW COLUMN")

    output_cols = list(para_list["feature_col"]) + [new_col_name]
    output["result"]["output_params"]["output_cols"] = output_cols

    # The one-hot encoded columns are internal features; hide them from the result.
    for col in encoded_cat_cols:
        df_pred = df_pred.drop(col)

    df_pred.show(3)
    return df_pred


def train(sc, dataframe, para_list, record, output):
    """Fit a Spark linear-regression model on *dataframe*.

    Parameters
    ----------
    sc : SparkContext passed through to the helper utilities.
    dataframe : training Spark DataFrame.
    para_list : dict with "feature_col", "label_col", "source",
        "reg_param" and "max_iter".
    record : dict mutated in place — metrics, coefficients and the final
        feature list are stored under record["other_info"].
    output : dict mutated in place — encoding warnings are appended to
        output["message"].

    Returns
    -------
    (model, df_pred): the fitted LinearRegressionModel and the training
    DataFrame with a float "prediction" column appended.
    """
    print("START TRAINING LR SPARK")
    # --------------------prepare data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])
    print("COLUMNS VALIDATED")
    print(num_cols, cat_cols)
    if len(cat_cols) != 0:
        dataframe, message, encoded_cat_cols = one_hot_encoding(sc, dataframe, para_list["source"], cat_cols)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    print("FEATURE LIST")
    df = feature_cols_labelling(dataframe, feature_list)
    df = label_col_labelling(df, para_list["label_col"])

    # ---------------------model fit---------------------
    print("MODEL FITTING")
    lr = LinearRegression(regParam=para_list["reg_param"], maxIter=para_list["max_iter"])

    model = lr.fit(df)
    print("LR MODEL FITTED")

    # ---------------------get metrics--------------------
    train_summary = model.summary
    print("METRICS CALCULATED")

    # ---------------------record info----------------------
    # (The old code also built a local `more_info` dict that was never stored
    # anywhere — dead code, removed.  All recorded values are unchanged.)
    record["other_info"] = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "coefficients": list(model.coefficients.toArray()),
        "intercept": model.intercept,
        "MSE": train_summary.meanSquaredError,
        "MAE": train_summary.meanAbsoluteError,
        "RMSE": train_summary.rootMeanSquaredError,
        "r2": train_summary.r2,
        "r2adj": train_summary.r2adj,
        "pvalues": list(train_summary.pValues),
        "tvalues": list(train_summary.tValues),
        "total_iterations": train_summary.totalIterations,
        "df": train_summary.degreesOfFreedom,
        # NOTE(review): pValues/tValues/coefficientStandardErrors are only
        # available for the "normal" solver in Spark — preserved from the
        # original; confirm the configured solver exposes them.
        "coefficient_standard_errors": train_summary.coefficientStandardErrors,
    }

    df_pred = model.transform(df)
    print("PROCESSING DF_PRED")
    # "rawPrediction"/"probability" do not exist for regression output;
    # dropping a missing column is a no-op in Spark, so this stays safe.
    df_pred = cast_to_float(df_pred, "prediction").drop("features").drop("label") \
        .drop("rawPrediction").drop("probability")
    print("RETURN")
    return model, df_pred
